# -*-Python-*-
import t5.models.mesh_transformer

# Macro holding the checkpoint-step selection; it is consumed further down
# via %eval_checkpoint_step when binding utils.run.eval_checkpoint_step.
# NOTE(review): "all" presumably selects every available checkpoint -- confirm
# against utils.run.
eval_checkpoint_step = "all"

# Override these values
utils.run.batch_size = ("tokens_per_replica", 32768)  # Decrease this if OOM errors are encountered

# Plumbing
# %MIXTURE_NAME is not defined in the visible part of this file, so it has to
# be supplied from outside (another gin file or an explicit binding).
mesh_eval_dataset_fn.mixture_or_task_name = %MIXTURE_NAME
# Use the dataset function from the module imported at the top of this file.
utils.run.eval_dataset_fn = @t5.models.mesh_transformer.mesh_eval_dataset_fn
utils.run.mode = "eval"

# Setting this lower makes decodes faster
Bitransformer.decode.max_decode_length = 128

# Forward the macro defined at the top of this file to the run function.
utils.run.eval_checkpoint_step = %eval_checkpoint_step

# NOTE(review): presumably limits the group size for which allreduce is done
# in bfloat16 -- confirm against the Mesh TensorFlow SimdMeshImpl docs.
SimdMeshImpl.allreduce_in_bfloat16_max_group_size = 32
